Use MPReplayBuffer with Ray

Please see this document, too.

Ray is an open-source framework which enables users to build distributed applications easily.

Ray can utilize multiple machines, so its architecture doesn't use shared memory except for immutable objects (Ref). However, if you use only a single machine, you might want to use shared memory for the replay buffer to avoid expensive interprocess data copying.
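
For example (a minimal sketch, not part of cpprb), immutable data is shared through Ray's object store with ray.put / ray.get, whereas a constantly updated replay buffer would have to be copied between processes on every access:

import numpy as np
import ray

ray.init()

# Immutable data can be placed into Ray's object store ...
obs = np.zeros(shape=(1000, 4))
ref = ray.put(obs)

# ... and fetched back. On a single machine this is a zero-copy view,
# and the returned array is typically read-only, so it cannot serve as
# a mutable replay buffer.
restored = ray.get(ref)
print(restored.flags.writeable)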

With cpprb 10.6+ and Python 3.8+, you can use MPReplayBuffer and MPPrioritizedReplayBuffer with Ray.

A key trick is to set the authkey inside the Ray worker processes, which allows Ray workers to communicate with the SyncManager process.

import base64
import multiprocessing as mp

import ray

ray.init()

# Encode with base64 to avoid the following error:
#   TypeError: Pickling an AuthenticationString object is disallowed for security reasons
encoded = base64.b64encode(mp.current_process().authkey)

def auth_fn(*args):
    mp.current_process().authkey = base64.b64decode(encoded)

ray.worker.global_worker.run_function_on_all_workers(auth_fn)

This trick overwrites the process-wide authkey, which might conflict with other multiprocessing usage inside the same workers. Additionally, run_function_on_all_workers() is neither Ray PublicAPI nor DeveloperAPI, so this trick might stop working in a future version.
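
If you want the second risk to surface loudly, a small defensive sketch (not part of cpprb; auth_fn comes from the snippet above) is to look the function up before relying on it:

run_on_all = getattr(ray.worker.global_worker,
                     "run_function_on_all_workers", None)
if run_on_all is None:
    raise RuntimeError("run_function_on_all_workers() is not available "
                       "in this Ray version; the authkey trick needs a "
                       "different hook.")
run_on_all(auth_fn)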

We also have to select the SharedMemory backend and a SyncManager context. With this configuration, the main data are placed in shared memory, while synchronization objects (e.g. Lock and Event) are accessed through SyncManager proxies.

import multiprocessing as mp

from cpprb import MPReplayBuffer

buffer_size = int(1e+6)
m = mp.get_context().Manager()

rb = MPReplayBuffer(buffer_size, {"done": {}},
                    ctx=m, backend="SharedMemory")
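
As a quick sanity check (a minimal sketch, not taken from the original example), the buffer behaves like an ordinary ReplayBuffer when used from the main process:

# Continues from the snippet above: `rb` is the MPReplayBuffer.
for done in (0, 0, 1):
    rb.add(done=done)

print(rb.get_stored_size())  # 3
sample = rb.sample(2)        # dict of NumPy arrays keyed by field name
print(sample["done"].shape)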

In the end, the (pseudo) example becomes like this:

# See: https://ymd_h.gitlab.io/cpprb/examples/mp_with_ray/

import base64
import multiprocessing as mp
import time

from cpprb import ReplayBuffer, MPPrioritizedReplayBuffer
import gym
import numpy as np
import ray


class Model:
    def __init__(self, env):
        self.env = env
        self.w = None

    def train(self, transitions):
        """
        Update model weights and return |TD|
        """
        absTD = np.zeros(shape=(transitions["obs"].shape[0],))
        # omit
        return absTD

    def __call__(self, obs):
        """
        Choose action from observation
        """
        # omit
        act = self.env.action_space.sample()
        return act

@ray.remote
def explorer(env_name, global_rb, env_dict, q, stop):
    try:
        buffer_size = 200

        local_rb = ReplayBuffer(buffer_size, env_dict)
        env = gym.make(env_name)
        model = Model(env)

        obs = env.reset()
        while not stop.is_set():
            if not q.empty():
                w = q.get()
                model.w = w

            act = model(obs)

            next_obs, rew, done, _ = env.step(act)

            local_rb.add(obs=obs, act=act, rew=rew, next_obs=next_obs, done=done)

            if done or local_rb.get_stored_size() == buffer_size:
                local_rb.on_episode_end()
                global_rb.add(**local_rb.get_all_transitions())
                local_rb.clear()
                obs = env.reset()
            else:
                obs = next_obs
    finally:
        stop.set()
    return None


def run():
    n_explorers = 4

    nwarmup = 100
    ntrain = int(1e+2)
    update_freq = 100
    env_name = "CartPole-v1"

    env = gym.make(env_name)

    buffer_size = int(1e+6)
    env_dict = {"obs": {"shape": env.observation_space.shape},
                "act": {},
                "rew": {},
                "next_obs": {"shape": env.observation_space.shape},
                "done": {}}
    alpha = 0.5
    beta = 0.4
    batch_size = 32


    ray.init()
    encoded = base64.b64encode(mp.current_process().authkey)
    def auth_fn(*args):
        mp.current_process().authkey = base64.b64decode(encoded)
    ray.worker.global_worker.run_function_on_all_workers(auth_fn)


    # `BaseContext.Manager()` automatically starts `SyncManager`
    # Ref: https://github.com/python/cpython/blob/3.9/Lib/multiprocessing/context.py#L49-L58
    m = mp.get_context().Manager()
    q = m.Queue()
    stop = m.Event()
    stop.clear()

    rb = MPPrioritizedReplayBuffer(buffer_size, env_dict, alpha=alpha,
                                   ctx=m, backend="SharedMemory")

    model = Model(env)

    explorers = []

    print("Start Explorers")
    for _ in range(n_explorers):
        explorers.append(explorer.remote(env_name, rb, env_dict, q, stop))


    print("Start Warmup")
    while rb.get_stored_size() < nwarmup and not stop.is_set():
        time.sleep(1)

    print("Start Training")
    for i in range(ntrain):
        if stop.is_set():
            break

        s = rb.sample(batch_size, beta)
        absTD = model.train(s)
        rb.update_priorities(s["indexes"], absTD)

        if i % update_freq == 0:
            q.put(model.w)
    print("Finish Training")


    stop.set()
    ray.get(explorers)

    m.shutdown()

if __name__ == "__main__":
    run()